package com.zhentao.studyim.netty;

import com.alibaba.fastjson2.JSON;
import com.zhentao.studyim.entity.Message;
import com.zhentao.studyim.service.MessageService;
import com.zhentao.studyim.service.RedisService;
import com.zhentao.studyim.util.JwtUtil;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.time.LocalDateTime;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * WebSocket message handler.
 *
 * <p>Handles the lifecycle and traffic of WebSocket connections:
 * <ul>
 *   <li>connection open/close bookkeeping</li>
 *   <li>user authentication with a JWT token</li>
 *   <li>real-time message delivery between users</li>
 *   <li>heartbeat replies</li>
 *   <li>per-user message rate limiting (delegated to {@link RedisService})</li>
 *   <li>message persistence (delegated to {@link MessageService})</li>
 *   <li>best-effort Redis caching of recent messages</li>
 * </ul>
 *
 * <p>Supported inbound message types (the JSON {@code type} field):
 * <ul>
 *   <li>{@code auth}: authenticate this connection using the {@code token} field</li>
 *   <li>{@code chat}: send {@code content} to the user identified by {@code toUserId}</li>
 *   <li>{@code heartbeat}: reply with a heartbeat carrying the server timestamp</li>
 * </ul>
 *
 * <p>The handler is marked {@link ChannelHandler.Sharable}, so one instance may be installed in
 * many pipelines; connection-to-user state therefore lives in a static concurrent map rather than
 * in instance fields.
 *
 * @author zhentao
 * @version 1.0
 * @since 2025-01-22
 */
@Slf4j
@Component
@ChannelHandler.Sharable
public class WebSocketHandler extends SimpleChannelInboundHandler<TextWebSocketFrame> {

    /**
     * Maps an authenticated user ID to the context of that user's registered connection.
     * Each user has at most one entry; authenticating again from another connection replaces it.
     */
    private static final Map<Long, ChannelHandlerContext> USER_CHANNELS = new ConcurrentHashMap<>();

    @Autowired
    private JwtUtil jwtUtil;

    @Autowired
    private MessageService messageService;

    @Autowired
    private RedisService redisService;

    /**
     * Called when a connection becomes active. Only logs; the connection stays anonymous until
     * an {@code auth} message arrives.
     */
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        log.debug("WebSocket连接建立: {}", ctx.channel().id());
        super.channelActive(ctx);
    }

    /**
     * Called when a connection goes inactive; unregisters it from {@link #USER_CHANNELS}.
     */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        log.debug("WebSocket连接断开: {}", ctx.channel().id());

        // Identity comparison: only entries that still point at THIS connection are removed, so a
        // user who has already re-registered on a newer connection keeps that newer mapping.
        USER_CHANNELS.entrySet().removeIf(entry -> entry.getValue() == ctx);

        super.channelInactive(ctx);
    }

    /**
     * Parses an inbound text frame as a JSON object and dispatches on its {@code type} field.
     * Any failure is logged and reported back to the client as an {@code error} message.
     */
    @Override
    protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame frame) throws Exception {
        String text = frame.text();
        log.debug("收到WebSocket消息: {}", text);

        try {
            @SuppressWarnings("unchecked")
            Map<String, Object> message = JSON.parseObject(text, Map.class);

            // Guard against a null parse result (e.g. an empty payload) and a missing type field:
            // switching on a null String would otherwise throw an opaque NullPointerException.
            Object rawType = message == null ? null : message.get("type");
            if (rawType == null) {
                sendErrorMessage(ctx, "消息缺少type字段");
                return;
            }
            String type = rawType.toString();

            switch (type) {
                case "auth":
                    handleAuth(ctx, message);
                    break;
                case "chat":
                    handleChat(ctx, message);
                    break;
                case "heartbeat":
                    handleHeartbeat(ctx);
                    break;
                default:
                    log.warn("未知消息类型: {}", type);
            }
        } catch (Exception e) {
            log.error("处理WebSocket消息失败", e);
            sendErrorMessage(ctx, "消息处理失败: " + e.getMessage());
        }
    }

    /**
     * Handles an {@code auth} message: resolves the user ID from the {@code token} field and
     * binds it to this connection, replying {@code auth_success} or an {@code error}.
     */
    private void handleAuth(ChannelHandlerContext ctx, Map<String, Object> message) {
        try {
            String token = (String) message.get("token");
            if (token == null || token.trim().isEmpty()) {
                sendErrorMessage(ctx, "认证失败: 缺少token");
                return;
            }

            // Validate the JWT and extract the user ID (exceptions are reported as auth failures).
            Long userId = jwtUtil.getUserIdFromToken(token);
            if (userId == null) {
                // ConcurrentHashMap rejects null keys; report it explicitly instead of an NPE.
                sendErrorMessage(ctx, "认证失败: 无效的token");
                return;
            }

            // A connection represents exactly one user. If this connection was previously
            // authenticated as someone else, drop that stale mapping so the old user's messages
            // are no longer routed here and the sender identity of this connection is unambiguous.
            USER_CHANNELS.entrySet().removeIf(entry -> entry.getValue() == ctx && !entry.getKey().equals(userId));
            USER_CHANNELS.put(userId, ctx);

            sendMessage(ctx, Map.of(
                    "type", "auth_success",
                    "message", "认证成功",
                    "userId", userId
            ));

            log.info("用户 {} 认证成功", userId);
        } catch (Exception e) {
            sendErrorMessage(ctx, "认证失败: " + e.getMessage());
        }
    }

    /**
     * Handles a {@code chat} message. The message is validated and rate-limited, then persisted
     * and cached (cache failures are ignored). It is forwarded to the recipient if they are
     * connected, and the sender gets a {@code message_sent} acknowledgement.
     */
    private void handleChat(ChannelHandlerContext ctx, Map<String, Object> message) {
        Long fromUserId = null;
        Long toUserId = null;
        String content = null;

        try {
            // The sender must have authenticated this connection first.
            fromUserId = getUserIdByChannel(ctx);
            if (fromUserId == null) {
                sendErrorMessage(ctx, "请先进行身份认证");
                return;
            }

            if (!redisService.checkMessageRateLimit(fromUserId)) {
                sendErrorMessage(ctx, "发送消息过于频繁，请稍后再试");
                return;
            }

            // Validate the recipient explicitly rather than letting a missing or malformed
            // toUserId surface as a NullPointerException / NumberFormatException message.
            Object rawToUserId = message.get("toUserId");
            if (rawToUserId == null) {
                sendErrorMessage(ctx, "接收者ID不能为空");
                return;
            }
            try {
                toUserId = Long.valueOf(rawToUserId.toString());
            } catch (NumberFormatException badId) {
                sendErrorMessage(ctx, "接收者ID格式错误");
                return;
            }

            content = (String) message.get("content");
            if (content == null || content.trim().isEmpty()) {
                sendErrorMessage(ctx, "消息内容不能为空");
                return;
            }

            log.info("开始处理聊天消息: {} -> {}, 内容: {}", fromUserId, toUserId, content);

            // Persist the message.
            Message chatMessage = new Message();
            chatMessage.setFromUserId(fromUserId);
            chatMessage.setToUserId(toUserId);
            chatMessage.setContent(content);
            chatMessage.setType(Message.MessageType.TEXT);
            chatMessage.setSendTime(LocalDateTime.now());

            log.debug("准备保存消息到数据库...");
            Message savedMessage = messageService.saveMessage(chatMessage);
            log.info("消息已保存到数据库，ID: {}", savedMessage.getId());

            // Caching is best-effort: a Redis failure must not fail an already-persisted message.
            try {
                redisService.cacheRecentMessage(fromUserId, toUserId, content);
                redisService.incrementMessageCount();
                log.debug("消息已缓存到Redis");
            } catch (Exception redisError) {
                log.warn("Redis缓存失败，但不影响消息发送: {}", redisError.getMessage());
            }

            Map<String, Object> forwardMessage = Map.of(
                    "type", "chat",
                    "messageId", savedMessage.getId(),
                    "fromUserId", fromUserId,
                    "toUserId", toUserId,
                    "content", content,
                    "sendTime", savedMessage.getSendTime().toString()
            );

            // Deliver only if the recipient has a registered, still-active connection.
            ChannelHandlerContext toChannel = USER_CHANNELS.get(toUserId);
            if (toChannel != null && toChannel.channel().isActive()) {
                sendMessage(toChannel, forwardMessage);
                log.info("消息已发送给接收者: {}", toUserId);
            } else {
                log.warn("接收者 {} 不在线或连接已断开", toUserId);
            }

            // Acknowledge to the sender (sent even when the recipient is offline).
            sendMessage(ctx, Map.of(
                    "type", "message_sent",
                    "messageId", savedMessage.getId(),
                    "status", "success"
            ));

            log.info("消息发送成功: {} -> {}", fromUserId, toUserId);
        } catch (Exception e) {
            log.error("发送消息失败: {} -> {}, 内容: {}, 错误: {}", fromUserId, toUserId, content, e.getMessage(), e);
            sendErrorMessage(ctx, "发送消息失败: " + e.getMessage());
        }
    }

    /**
     * Replies to a heartbeat with a {@code heartbeat} message carrying the current epoch millis.
     */
    private void handleHeartbeat(ChannelHandlerContext ctx) {
        sendMessage(ctx, Map.of("type", "heartbeat", "timestamp", System.currentTimeMillis()));
    }

    /**
     * Serializes {@code message} to JSON and writes it to the client as a text frame.
     */
    private void sendMessage(ChannelHandlerContext ctx, Map<String, Object> message) {
        String json = JSON.toJSONString(message);
        ctx.writeAndFlush(new TextWebSocketFrame(json));
    }

    /**
     * Sends an {@code error} message with the given text. Note: {@code Map.of} rejects null
     * values, so {@code error} must be non-null.
     */
    private void sendErrorMessage(ChannelHandlerContext ctx, String error) {
        sendMessage(ctx, Map.of("type", "error", "message", error));
    }

    /**
     * Looks up the user ID registered for this connection.
     *
     * <p>This is a linear scan over all registered users.
     *
     * @return the user ID, or {@code null} if this connection has not authenticated
     */
    private Long getUserIdByChannel(ChannelHandlerContext ctx) {
        return USER_CHANNELS.entrySet().stream()
                .filter(entry -> entry.getValue() == ctx)
                .map(Map.Entry::getKey)
                .findFirst()
                .orElse(null);
    }

    /**
     * Logs any pipeline exception and closes the connection; {@link #channelInactive} then
     * performs the unregistration.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        log.error("WebSocket异常", cause);
        ctx.close();
    }
}
